{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 从零实现llama3\n",
    "在这个文件中，我从零开始实现了llama3，一次一个张量和矩阵乘法。\n",
    "<br>\n",
    "此外，我将直接从Meta提供的llama3模型文件中加载张量，在运行此文件之前，您需要下载模型的权重。\n",
    "以下是下载权重的官方链接：https://llama.meta.com/llama-downloads/\n",
    "\n",
    "<div>\n",
    "    <img src=\"images/archi.png\"/>\n",
    "</div>\n",
    "\n",
    "## 分词器\n",
    "我不会实现一个bpe分词器（但Andrej Karpathy有一个非常简洁的实现）\n",
    "<br>\n",
    "链接到他的实现：https://github.com/karpathy/minbpe\n",
    "\n",
    "<div>\n",
    "    <img src=\"images/karpathyminbpe.png\" width=\"600\"/>\n",
    "</div>"
   ]
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "from pathlib import Path\n",
    "import tiktoken\n",
    "from tiktoken.load import load_tiktoken_bpe\n",
    "import torch\n",
    "import json\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "tokenizer_path = \"Meta-Llama-3-8B/tokenizer.model\"\n",
    "special_tokens = [\n",
    "            \"<|begin_of_text|>\",\n",
    "            \"<|end_of_text|>\",\n",
    "            \"<|reserved_special_token_0|>\",\n",
    "            \"<|reserved_special_token_1|>\",\n",
    "            \"<|reserved_special_token_2|>\",\n",
    "            \"<|reserved_special_token_3|>\",\n",
    "            \"<|start_header_id|>\",\n",
    "            \"<|end_header_id|>\",\n",
    "            \"<|reserved_special_token_4|>\",\n",
    "            \"<|eot_id|>\",  # end of turn\n",
    "        ] + [f\"<|reserved_special_token_{i}|>\" for i in range(5, 256 - 5)]\n",
    "mergeable_ranks = load_tiktoken_bpe(tokenizer_path)\n",
    "tokenizer = tiktoken.Encoding(\n",
    "    name=Path(tokenizer_path).name,\n",
    "    pat_str=r\"(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+\",\n",
    "    mergeable_ranks=mergeable_ranks,\n",
    "    special_tokens={token: len(mergeable_ranks) + i for i, token in enumerate(special_tokens)},\n",
    ")\n",
    "\n",
    "tokenizer.decode(tokenizer.encode(\"hello world!\"))"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 读取模型文件\n",
    "通常，读取模型文件取决于模型类的编写方式以及内部的变量名。\n",
    "<br>\n",
    "但由于我们从零开始实现llama3，所以我们将一次读取一个张量。\n",
    "<div>\n",
    "    <img src=\"images/model.png\" width=\"600\"/>\n",
    "</div>"
   ]
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "model = torch.load(\"Meta-Llama-3-8B/consolidated.00.pth\")\n",
    "print(json.dumps(list(model.keys())[:20], indent=4))"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "with open(\"Meta-Llama-3-8B/params.json\", \"r\") as f:\n",
    "    config = json.load(f)\n",
    "config"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 我们使用此配置来推断有关模型的细节，例如\n",
    "1. 该模型有32个transformer层\n",
    "2. 每个多头注意力块有32个头\n",
    "3. 词汇量大小等等"
   ]
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "dim = config[\"dim\"]\n",
    "n_layers = config[\"n_layers\"]\n",
    "n_heads = config[\"n_heads\"]\n",
    "n_kv_heads = config[\"n_kv_heads\"]\n",
    "vocab_size = config[\"vocab_size\"]\n",
    "multiple_of = config[\"multiple_of\"]\n",
    "ffn_dim_multiplier = config[\"ffn_dim_multiplier\"]\n",
    "norm_eps = config[\"norm_eps\"]\n",
    "rope_theta = torch.tensor(config[\"rope_theta\"])"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 将文本转换为标记\n",
    "这里我们使用tiktoken（我认为是OpenAI的一个库）作为分词器\n",
    "<div>\n",
    "    <img src=\"images/tokens.png\" width=\"600\"/>\n",
    "</div>"
   ]
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "prompt = \"the answer to the ultimate question of life, the universe, and everything is \"\n",
    "tokens = [128000] + tokenizer.encode(prompt)\n",
    "print(tokens)\n",
    "tokens = torch.tensor(tokens)\n",
    "prompt_split_as_tokens = [tokenizer.decode([token.item()]) for token in tokens]\n",
    "print(prompt_split_as_tokens)"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 将标记转换为它们的嵌入\n",
    "抱歉，这是代码库中唯一使用内置神经网络模块的部分。\n",
    "<br>\n",
    "无论如何，我们的 [17x1] 标记现在变为 [17x4096]，即17个长度为4096的嵌入（每个标记一个嵌入）。\n",
    "<br>\n",
    "<br>\n",
    "注意：跟踪这些形状，会让理解所有内容变得更容易。\n",
    "\n",
    "<div>\n",
    "    <img src=\"images/embeddings.png\" width=\"600\"/>\n",
    "</div>"
   ]
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "embedding_layer = torch.nn.Embedding(vocab_size, dim)\n",
    "embedding_layer.weight.data.copy_(model[\"tok_embeddings.weight\"])\n",
    "token_embeddings_unnormalized = embedding_layer(tokens).to(torch.bfloat16)\n",
    "token_embeddings_unnormalized.shape"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 然后我们使用rms归一化对嵌入进行标准化\n",
    "请注意，经过此步骤后，形状不会改变，值只是被归一化了。\n",
    "<br>\n",
    "需要记住的是，我们需要一个norm_eps（来自配置），因为我们不希望不小心将rms设置为0并导致除以0。\n",
    "<br>\n",
    "这里是公式：\n",
    "<div>\n",
    "    <img src=\"images/rms.png\" width=\"600\"/>\n",
    "</div>"
   ]
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "# def rms_norm(tensor, norm_weights):\n",
    "#     rms = (tensor.pow(2).mean(-1, keepdim=True) + norm_eps)**0.5\n",
    "#     return tensor * (norm_weights / rms)\n",
    "def rms_norm(tensor, norm_weights):\n",
    "    return (tensor * torch.rsqrt(tensor.pow(2).mean(-1, keepdim=True) + norm_eps)) * norm_weights"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 构建Transformer的第一层\n",
    "\n",
    "### 归一化\n",
    "你会看到我从模型字典中访问 layer.0（这是第一层）。\n",
    "<br>\n",
    "总之，经过归一化处理后，我们的形状仍然是 [17x4096]，与嵌入相同，但已归一化。\n",
    "\n",
    "<div>\n",
    "    <img src=\"images/norm.png\" width=\"600\"/>\n",
    "</div>"
   ]
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "token_embeddings = rms_norm(token_embeddings_unnormalized, model[\"layers.0.attention_norm.weight\"])\n",
    "token_embeddings.shape"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 从零实现注意力机制\n",
    "让我们加载Transformer第一层的注意力头\n",
    "\n",
    "<div>\n",
    "    <img src=\"images/qkv.png\" width=\"600\"/>\n",
    "</div>\n",
    "\n",
    "<br>\n",
    "\n",
    "&gt; 当我们从模型中加载query、key、value和output向量时，我们注意到它们的形状分别是[4096x4096]、[1024x4096]、[1024x4096]、[4096x4096]\n",
    "<br>\n",
    "&gt; 乍一看这很奇怪，因为理想情况下我们希望为每个head单独获取q、k、v和o\n",
    "<br>\n",
    "&gt; 代码的作者将它们捆绑在一起是因为这样可以更容易并行化注意力头的乘法运算。\n",
    "<br>\n",
    "&gt; 我现在要将它们解开..."
   ]
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "print(\n",
    "    model[\"layers.0.attention.wq.weight\"].shape,\n",
    "    model[\"layers.0.attention.wk.weight\"].shape,\n",
    "    model[\"layers.0.attention.wv.weight\"].shape,\n",
    "    model[\"layers.0.attention.wo.weight\"].shape\n",
    ")"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 解开查询\n",
    "在接下来的部分中，我们将解开多个注意力头中的查询，结果的形状是 [32x128x4096]。\n",
    "<br><br>\n",
    "这里，32 是llama3中的注意力头数，128 是查询向量的大小，4096 是标记嵌入的大小。"
   ]
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "q_layer0 = model[\"layers.0.attention.wq.weight\"]\n",
    "head_dim = q_layer0.shape[0] // n_heads\n",
    "q_layer0 = q_layer0.view(n_heads, head_dim, dim)\n",
    "q_layer0.shape"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 我将实现第一层的第一个头\n",
    "这里我首先访问第一层第一个头的查询权重矩阵，这个查询权重矩阵的大小是 [128x4096]。"
   ]
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "q_layer0_head0 = q_layer0[0]\n",
    "q_layer0_head0.shape"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 现在我们将查询权重与标记嵌入相乘，得到该标记的查询\n",
    "这里你可以看到结果的形状是 [17x128]，这是因为我们有17个标记，并且对于每个标记都有一个长度为128的查询。\n",
    "<div>\n",
    "    <img src=\"images/q_per_token.png\" width=\"600\"/>\n",
    "</div>"
   ]
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "q_per_token = torch.matmul(token_embeddings, q_layer0_head0.T)\n",
    "q_per_token.shape"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 位置编码\n",
    "我们现在已经为提示中的每个标记生成了查询向量，但如果你仔细想想——单独的查询向量并不了解它在提示中的位置。\n",
    "<br><br>\n",
    "示例查询：“the answer to the ultimate question of life, the universe, and everything is ”\n",
    "<br><br>\n",
    "在我们的提示中，\"the\" 出现了三次，我们需要基于它们在查询中的位置，让这三个 \"the\" 标记的查询向量有不同的查询向量（每个大小为 [1x128]）。我们使用RoPE（旋转位置嵌入）来执行这些旋转。\n",
    "<br><br>\n",
    "### RoPE\n",
    "观看这个视频（我也是通过这个视频理解的）来了解其中的数学原理。\n",
    "https://www.youtube.com/watch?v=o29P0Kpobz0&t=530s\n",
    "\n",
    "<div>\n",
    "    <img src=\"images/rope.png\" width=\"600\"/>\n",
    "</div>"
   ]
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "q_per_token_split_into_pairs = q_per_token.float().view(q_per_token.shape[0], -1, 2)\n",
    "q_per_token_split_into_pairs.shape"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "在上一步中，我们将查询向量分成了对，并对每对应用一个旋转角度偏移！\n",
    "<br><br>\n",
    "现在我们得到一个大小为 [17x64x2] 的向量，这是128长度的查询向量被分成了64对，每个标记对应提示中的一个查询！这64对中的每一对都将通过 m*(theta) 旋转，其中 m 是我们为其旋转查询的标记位置！"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 使用复数的点积来旋转向量\n",
    "<div>\n",
    "    <img src=\"images/freq_cis.png\" width=\"600\"/>\n",
    "</div>"
   ]
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "zero_to_one_split_into_64_parts = torch.tensor(range(64))/64\n",
    "zero_to_one_split_into_64_parts"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "freqs = 1.0 / (rope_theta ** zero_to_one_split_into_64_parts)\n",
    "freqs"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "freqs_for_each_token = torch.outer(torch.arange(17), freqs)\n",
    "freqs_cis = torch.polar(torch.ones_like(freqs_for_each_token), freqs_for_each_token)\n",
    "freqs_cis.shape\n",
    "\n",
    "# viewing tjhe third row of freqs_cis\n",
    "value = freqs_cis[3]\n",
    "plt.figure()\n",
    "for i, element in enumerate(value[:17]):\n",
    "    plt.plot([0, element.real], [0, element.imag], color='blue', linewidth=1, label=f\"Index: {i}\")\n",
    "    plt.annotate(f\"{i}\", xy=(element.real, element.imag), color='red')\n",
    "plt.xlabel('Real')\n",
    "plt.ylabel('Imaginary')\n",
    "plt.title('Plot of one row of freqs_cis')\n",
    "plt.show()"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 现在我们为每个标记的查询元素得到了一个复数（角度变化向量）\n",
    "我们可以将我们的查询（之前分成的对）转换为复数，然后通过点积运算来根据位置旋转查询。\n",
    "<br>\n",
    "说实话，这个过程真的很美妙 :)"
   ]
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "q_per_token_as_complex_numbers = torch.view_as_complex(q_per_token_split_into_pairs)\n",
    "q_per_token_as_complex_numbers.shape"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "q_per_token_as_complex_numbers_rotated = q_per_token_as_complex_numbers * freqs_cis\n",
    "q_per_token_as_complex_numbers_rotated.shape"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 得到旋转后的向量之后\n",
    "我们可以通过将复数再次视为实数来恢复查询向量为对的形式。"
   ]
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "q_per_token_split_into_pairs_rotated = torch.view_as_real(q_per_token_as_complex_numbers_rotated)\n",
    "q_per_token_split_into_pairs_rotated.shape"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": "旋转后的对现在已合并，我们现在得到了一个新的查询向量（旋转后的查询向量），其形状为 [17x128]，其中17是标记的数量，128是查询向量的维度。"
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "q_per_token_rotated = q_per_token_split_into_pairs_rotated.view(q_per_token.shape)\n",
    "q_per_token_rotated.shape"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 键（几乎和查询相同）\n",
    "<div>\n",
    "    <img src=\"images/keys.png\" width=\"600px\"/>\n",
    "</div>\n",
    "我超级懒，所以我不会详细讲解键的数学部分，你只需要记住以下几点：\n",
    "<br>\n",
    "&gt; 键生成的键向量的维度也是128\n",
    "<br>\n",
    "&gt; 键的权重数量只有查询的1/4，这是因为键的权重在一次计算中会在4个头之间共享，从而减少所需的计算量\n",
    "<br>\n",
    "&gt; 键同样会通过旋转添加位置信息，和查询一样，出于相同的原因"
   ]
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "k_layer0 = model[\"layers.0.attention.wk.weight\"]\n",
    "k_layer0 = k_layer0.view(n_kv_heads, k_layer0.shape[0] // n_kv_heads, dim)\n",
    "k_layer0.shape"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "k_layer0_head0 = k_layer0[0]\n",
    "k_layer0_head0.shape"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "k_per_token = torch.matmul(token_embeddings, k_layer0_head0.T)\n",
    "k_per_token.shape"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "k_per_token_split_into_pairs = k_per_token.float().view(k_per_token.shape[0], -1, 2)\n",
    "k_per_token_split_into_pairs.shape"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "k_per_token_as_complex_numbers = torch.view_as_complex(k_per_token_split_into_pairs)\n",
    "k_per_token_as_complex_numbers.shape"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "k_per_token_split_into_pairs_rotated = torch.view_as_real(k_per_token_as_complex_numbers * freqs_cis)\n",
    "k_per_token_split_into_pairs_rotated.shape"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "k_per_token_rotated = k_per_token_split_into_pairs_rotated.view(k_per_token.shape)\n",
    "k_per_token_rotated.shape"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 现在我们已经得到了每个标记的查询和键的旋转值\n",
    "<div>\n",
    "    <img src=\"images/keys0.png\" width=\"600px\"/>\n",
    "</div>\n",
    "每个查询和键的形状现在都是 [17x128]。\n",
    "\n",
    "## 在下一步中，我们将对查询和键矩阵进行相乘\n",
    "这样做将为我们提供每个标记与其他标记之间的得分映射。\n",
    "<br>\n",
    "这个得分描述了每个标记的查询与每个标记的键之间的关联程度。\n",
    "这就是自注意力机制 :)\n",
    "<br>\n",
    "注意力得分矩阵（qk_per_token）的形状是 [17x17]，其中17是提示中的标记数量。\n",
    "\n",
    "<div>\n",
    "    <img src=\"images/qkmatmul.png\" width=\"600px\"/>\n",
    "</div>"
   ]
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "qk_per_token = torch.matmul(q_per_token_rotated, k_per_token_rotated.T)/(head_dim)**0.5\n",
    "qk_per_token.shape"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 现在我们需要对查询-键得分进行掩码处理\n",
    "在llama3的训练过程中，未来标记的qk得分会被掩码。\n",
    "<br>\n",
    "为什么？因为在训练过程中，我们只能使用过去的标记来预测当前标记。\n",
    "<br>\n",
    "因此，在推理过程中，我们将未来的标记得分设置为零。\n",
    "\n",
    "<div>\n",
    "    <img src=\"images/mask.png\" width=\"600px\"/>\n",
    "</div>"
   ]
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "def display_qk_heatmap(qk_per_token):\n",
    "    _, ax = plt.subplots()\n",
    "    im = ax.imshow(qk_per_token.to(float).detach(), cmap='viridis')\n",
    "    ax.set_xticks(range(len(prompt_split_as_tokens)))\n",
    "    ax.set_yticks(range(len(prompt_split_as_tokens)))\n",
    "    ax.set_xticklabels(prompt_split_as_tokens)\n",
    "    ax.set_yticklabels(prompt_split_as_tokens)\n",
    "    ax.figure.colorbar(im, ax=ax)\n",
    "    \n",
    "display_qk_heatmap(qk_per_token)"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "mask = torch.full((len(tokens), len(tokens)), float(\"-inf\"), device=tokens.device)\n",
    "mask = torch.triu(mask, diagonal=1)\n",
    "mask"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "qk_per_token_after_masking = qk_per_token + mask\n",
    "display_qk_heatmap(qk_per_token_after_masking)"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<div>\n",
    "    <img src=\"images/softmax.png\" width=\"600px\"/>\n",
    "</div>"
   ]
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "qk_per_token_after_masking_after_softmax = torch.nn.functional.softmax(qk_per_token_after_masking, dim=1).to(torch.bfloat16)\n",
    "display_qk_heatmap(qk_per_token_after_masking_after_softmax)"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 值（注意力机制的接近尾声）\n",
    "\n",
    "<div>\n",
    "    <img src=\"images/value.png\" width=\"600px\"/>\n",
    "</div>\n",
    "\n",
    "这些得分（0-1）用于确定每个标记使用多少值矩阵。\n",
    "<br>\n",
    "&gt; 和键一样，值的权重也是在每4个注意力头之间共享的（以节省计算）。\n",
    "<br>\n",
    "&gt; 因此，下面值权重矩阵的形状为 [8x128x4096]。"
   ]
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "v_layer0 = model[\"layers.0.attention.wv.weight\"]\n",
    "v_layer0 = v_layer0.view(n_kv_heads, v_layer0.shape[0] // n_kv_heads, dim)\n",
    "v_layer0.shape"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": "第一层第一个头的值权重矩阵如下所示："
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "v_layer0_head0 = v_layer0[0]\n",
    "v_layer0_head0.shape"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 值向量\n",
    "<div>\n",
    "    <img src=\"images/v0.png\" width=\"600px\"/>\n",
    "</div>\n",
    "我们现在使用值权重来获取每个标记的注意力值，其大小为 [17x128]，其中17是提示中的标记数量，128是每个标记的值向量的维度。"
   ]
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "v_per_token = torch.matmul(token_embeddings, v_layer0_head0.T)\n",
    "v_per_token.shape"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 注意力\n",
    "<div>\n",
    "    <img src=\"images/attention.png\" width=\"600px\"/>\n",
    "</div>\n",
    "与每个标记的值相乘后得到的注意力向量的形状是 [17x128]。"
   ]
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "qkv_attention = torch.matmul(qk_per_token_after_masking_after_softmax, v_per_token)\n",
    "qkv_attention.shape"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 多头注意力\n",
    "<div>\n",
    "    <img src=\"images/heads.png\" width=\"600px\"/>\n",
    "</div>\n",
    "我们现在得到了第一层和第一个头的注意力值。\n",
    "<br>\n",
    "接下来，我将运行一个循环，并对第一层中的每个头执行与上面相同的数学运算。"
   ]
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "qkv_attention_store = []\n",
    "\n",
    "for head in range(n_heads):\n",
    "    q_layer0_head = q_layer0[head]\n",
    "    k_layer0_head = k_layer0[head//4] # key weights are shared across 4 heads\n",
    "    v_layer0_head = v_layer0[head//4] # value weights are shared across 4 heads\n",
    "    q_per_token = torch.matmul(token_embeddings, q_layer0_head.T)\n",
    "    k_per_token = torch.matmul(token_embeddings, k_layer0_head.T)\n",
    "    v_per_token = torch.matmul(token_embeddings, v_layer0_head.T)\n",
    "\n",
    "    q_per_token_split_into_pairs = q_per_token.float().view(q_per_token.shape[0], -1, 2)\n",
    "    q_per_token_as_complex_numbers = torch.view_as_complex(q_per_token_split_into_pairs)\n",
    "    q_per_token_split_into_pairs_rotated = torch.view_as_real(q_per_token_as_complex_numbers * freqs_cis[:len(tokens)])\n",
    "    q_per_token_rotated = q_per_token_split_into_pairs_rotated.view(q_per_token.shape)\n",
    "\n",
    "    k_per_token_split_into_pairs = k_per_token.float().view(k_per_token.shape[0], -1, 2)\n",
    "    k_per_token_as_complex_numbers = torch.view_as_complex(k_per_token_split_into_pairs)\n",
    "    k_per_token_split_into_pairs_rotated = torch.view_as_real(k_per_token_as_complex_numbers * freqs_cis[:len(tokens)])\n",
    "    k_per_token_rotated = k_per_token_split_into_pairs_rotated.view(k_per_token.shape)\n",
    "\n",
    "    qk_per_token = torch.matmul(q_per_token_rotated, k_per_token_rotated.T)/(128)**0.5\n",
    "    mask = torch.full((len(tokens), len(tokens)), float(\"-inf\"), device=tokens.device)\n",
    "    mask = torch.triu(mask, diagonal=1)\n",
    "    qk_per_token_after_masking = qk_per_token + mask\n",
    "    qk_per_token_after_masking_after_softmax = torch.nn.functional.softmax(qk_per_token_after_masking, dim=1).to(torch.bfloat16)\n",
    "    qkv_attention = torch.matmul(qk_per_token_after_masking_after_softmax, v_per_token)\n",
    "    qkv_attention = torch.matmul(qk_per_token_after_masking_after_softmax, v_per_token)\n",
    "    qkv_attention_store.append(qkv_attention)\n",
    "\n",
    "len(qkv_attention_store)"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#\n",
    "<div>\n",
    "    <img src=\"images/stacked.png\" width=\"600px\"/>\n",
    "</div>\n",
    "我们现在得到了第一层中32个头的qkv_attention矩阵，接下来我要将所有的注意力得分合并成一个大小为 [17x4096] 的大矩阵。\n",
    "<br>\n",
    "我们快要完成了 :)"
   ]
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "stacked_qkv_attention = torch.cat(qkv_attention_store, dim=-1)\n",
    "stacked_qkv_attention.shape"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 权重矩阵，最后的步骤之一\n",
    "<div>\n",
    "    <img src=\"images/weightmatrix.png\" width=\"600px\"/>\n",
    "</div>\n",
    "对于第0层注意力机制，最后要做的事情之一是将其与权重矩阵相乘。"
   ]
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "w_layer0 = model[\"layers.0.attention.wo.weight\"]\n",
    "w_layer0.shape"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": "### 这是一个简单的线性层，所以我们只需进行矩阵相乘"
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "embedding_delta = torch.matmul(stacked_qkv_attention, w_layer0.T)\n",
    "embedding_delta.shape"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#\n",
    "<div>\n",
    "    <img src=\"images/afterattention.png\" width=\"600px\"/>\n",
    "</div>\n",
    "我们现在得到了注意力机制之后的嵌入值变化，它应该被加到原始的标记嵌入上。"
   ]
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "embedding_after_edit = token_embeddings_unnormalized + embedding_delta\n",
    "embedding_after_edit.shape"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 我们对嵌入增量进行归一化，然后通过一个前馈神经网络\n",
    "<div>\n",
    "    <img src=\"images/norm_after.png\" width=\"600px\"/>\n",
    "</div>"
   ]
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "embedding_after_edit_normalized = rms_norm(embedding_after_edit, model[\"layers.0.ffn_norm.weight\"])\n",
    "embedding_after_edit_normalized.shape"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 加载前馈网络的权重并实现前馈网络\n",
    "<div>\n",
    "    <img src=\"images/swiglu.png\" width=\"600px\"/>\n",
    "</div>\n",
    "在llama3中，他们使用了SwiGLU前馈网络，这种网络架构在模型需要时非常擅长添加非线性元素。\n",
    "<br>\n",
    "在如今的大型语言模型中，使用这种前馈网络架构已经相当普遍。"
   ]
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "w1 = model[\"layers.0.feed_forward.w1.weight\"]\n",
    "w2 = model[\"layers.0.feed_forward.w2.weight\"]\n",
    "w3 = model[\"layers.0.feed_forward.w3.weight\"]\n",
    "output_after_feedforward = torch.matmul(torch.functional.F.silu(torch.matmul(embedding_after_edit_normalized, w1.T)) * torch.matmul(embedding_after_edit_normalized, w3.T), w2.T)\n",
    "output_after_feedforward.shape"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 我们终于得到了第一层之后每个标记的新编辑嵌入\n",
    "只剩下31层就完成了（只需一个循环）。\n",
    "<br>\n",
    "你可以想象这个编辑后的嵌入包含了在第一层中所有查询的信息。\n",
    "<br>\n",
    "现在每一层都会对所提出的问题进行越来越复杂的编码，直到我们得到一个可以完全推测下一个标记的嵌入。"
   ]
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "layer_0_embedding = embedding_after_edit+output_after_feedforward\n",
    "layer_0_embedding.shape"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 天啊，一切都在一起\n",
    "<div>\n",
    "    <img src=\"images/god.png\" width=\"600px\"/>\n",
    "</div>\n",
    "没错，就是这样。我们之前做的所有事情，现在一次性完成，针对每一层。\n",
    "<br>\n",
    "\n",
    "# 祝你阅读愉快 :)"
   ]
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "final_embedding = token_embeddings_unnormalized\n",
    "for layer in range(n_layers):\n",
    "    qkv_attention_store = []\n",
    "    layer_embedding_norm = rms_norm(final_embedding, model[f\"layers.{layer}.attention_norm.weight\"])\n",
    "    q_layer = model[f\"layers.{layer}.attention.wq.weight\"]\n",
    "    q_layer = q_layer.view(n_heads, q_layer.shape[0] // n_heads, dim)\n",
    "    k_layer = model[f\"layers.{layer}.attention.wk.weight\"]\n",
    "    k_layer = k_layer.view(n_kv_heads, k_layer.shape[0] // n_kv_heads, dim)\n",
    "    v_layer = model[f\"layers.{layer}.attention.wv.weight\"]\n",
    "    v_layer = v_layer.view(n_kv_heads, v_layer.shape[0] // n_kv_heads, dim)\n",
    "    w_layer = model[f\"layers.{layer}.attention.wo.weight\"]\n",
    "    for head in range(n_heads):\n",
    "        q_layer_head = q_layer[head]\n",
    "        k_layer_head = k_layer[head//4]\n",
    "        v_layer_head = v_layer[head//4]\n",
    "        q_per_token = torch.matmul(layer_embedding_norm, q_layer_head.T)\n",
    "        k_per_token = torch.matmul(layer_embedding_norm, k_layer_head.T)\n",
    "        v_per_token = torch.matmul(layer_embedding_norm, v_layer_head.T)\n",
    "        q_per_token_split_into_pairs = q_per_token.float().view(q_per_token.shape[0], -1, 2)\n",
    "        q_per_token_as_complex_numbers = torch.view_as_complex(q_per_token_split_into_pairs)\n",
    "        q_per_token_split_into_pairs_rotated = torch.view_as_real(q_per_token_as_complex_numbers * freqs_cis)\n",
    "        q_per_token_rotated = q_per_token_split_into_pairs_rotated.view(q_per_token.shape)\n",
    "        k_per_token_split_into_pairs = k_per_token.float().view(k_per_token.shape[0], -1, 2)\n",
    "        k_per_token_as_complex_numbers = torch.view_as_complex(k_per_token_split_into_pairs)\n",
    "        k_per_token_split_into_pairs_rotated = torch.view_as_real(k_per_token_as_complex_numbers * freqs_cis)\n",
    "        k_per_token_rotated = k_per_token_split_into_pairs_rotated.view(k_per_token.shape)\n",
    "        qk_per_token = torch.matmul(q_per_token_rotated, k_per_token_rotated.T)/(128)**0.5\n",
    "        mask = torch.full((len(token_embeddings_unnormalized), len(token_embeddings_unnormalized)), float(\"-inf\"))\n",
    "        mask = torch.triu(mask, diagonal=1)\n",
    "        qk_per_token_after_masking = qk_per_token + mask\n",
    "        qk_per_token_after_masking_after_softmax = torch.nn.functional.softmax(qk_per_token_after_masking, dim=1).to(torch.bfloat16)\n",
    "        qkv_attention = torch.matmul(qk_per_token_after_masking_after_softmax, v_per_token)\n",
    "        qkv_attention_store.append(qkv_attention)\n",
    "\n",
    "    stacked_qkv_attention = torch.cat(qkv_attention_store, dim=-1)\n",
    "    w_layer = model[f\"layers.{layer}.attention.wo.weight\"]\n",
    "    embedding_delta = torch.matmul(stacked_qkv_attention, w_layer.T)\n",
    "    embedding_after_edit = final_embedding + embedding_delta\n",
    "    embedding_after_edit_normalized = rms_norm(embedding_after_edit, model[f\"layers.{layer}.ffn_norm.weight\"])\n",
    "    w1 = model[f\"layers.{layer}.feed_forward.w1.weight\"]\n",
    "    w2 = model[f\"layers.{layer}.feed_forward.w2.weight\"]\n",
    "    w3 = model[f\"layers.{layer}.feed_forward.w3.weight\"]\n",
    "    output_after_feedforward = torch.matmul(torch.functional.F.silu(torch.matmul(embedding_after_edit_normalized, w1.T)) * torch.matmul(embedding_after_edit_normalized, w3.T), w2.T)\n",
    "    final_embedding = embedding_after_edit+output_after_feedforward"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 最后，让我们将嵌入解码为标记值\n",
    "<div>\n",
    "    <img src=\"images/finallayer.png\" width=\"600px\"/>\n",
    "</div>\n",
    "我们将使用输出解码器将最终的嵌入转换为一个标记。"
   ]
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "final_embedding = rms_norm(final_embedding, model[\"norm.weight\"])\n",
    "final_embedding.shape"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# finally, lets decode the embedding into the token value\n",
    "<div>\n",
    "    <img src=\"images/finallayer.png\" width=\"600px\"/>\n",
    "</div>\n",
    "we will use the output decoder to convert the final embedding into a token"
   ]
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "model[\"output.weight\"].shape"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 我们使用最后一个标记的嵌入来预测下一个值\n",
    "希望在我们的案例中，答案是42 :)\n",
    "注意：42是书籍《银河系漫游指南》中对“生命、宇宙及一切终极问题的答案”的回答。根据这本书，大多数现代大型语言模型都会在这里回答42，这应该验证我们整个代码的正确性！祝我好运 :)"
   ]
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "logits = torch.matmul(final_embedding[-1], model[\"output.weight\"].T)\n",
    "logits.shape"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 模型预测下一个标记是编号2983的标记，这是“42”的标记编号吗？\n",
    "我为你打气！这是最后一个代码单元，希望你玩得开心 :)"
   ]
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "next_token = torch.argmax(logits, dim=-1)\n",
    "next_token"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# lets fucking go\n",
    "<div>\n",
    "    <img src=\"images/42.png\" width=\"600px\"/>\n",
    "</div>"
   ]
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "tokenizer.decode([next_token.item()])"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# thank you, i love you :)\n",
    "\n",
    "This is the end. Hopefully you enjoyed reading it!\n",
    "\n",
    "If you want to support my work\n",
    "\n",
    "1. follow me on twitter https://twitter.com/naklecha \n",
    "2. or, buy me a coffee [https://www.buymeacoffee.com/naklecha](https://www.buymeacoffee.com/naklecha)\n",
    "\n",
    "Honestly, if you made it this far you already made my day :)\n",
    "\n",
    "## what motivates me?\n",
    "\n",
    "My friends and I are on a mission - to make research more accessible!\n",
    "We created a research lab called A10 - [AAAAAAAAAA.org](http://aaaaaaaaaa.org/)\n",
    "\n",
    "A10 twitter - https://twitter.com/aaaaaaaaaaorg\n",
    "\n",
    "our thesis:\n",
    "<div>\n",
    "    <img src=\"images/a10.png\" width=\"600px\"/>\n",
    "</div>"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.14"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
